import torch
def get_dataloader(args, trainset, testset, rank):
    """Build the train and test DataLoaders, optionally sharded per process.

    Args:
        args: Configuration object. Reads ``batch_size_per_gpu`` and
            ``num_workers_per_gpu`` always, and ``world_size`` only when
            ``rank`` is not None.
        trainset: Dataset wrapped by the training loader.
        testset: Dataset wrapped by the test loader.
        rank: Index of the current process among ``args.world_size``
            replicas, or None to build plain (non-distributed) loaders.

    Returns:
        A ``(trainloader, testloader)`` tuple.
    """
    distributed = rank is not None
    train_sampler = None
    test_sampler = None
    if distributed:
        sampler_cls = torch.utils.data.distributed.DistributedSampler
        # NOTE: per the PyTorch docs, the caller should invoke
        # trainloader.sampler.set_epoch(epoch) each epoch, otherwise the
        # same shuffled order is reused every time.
        train_sampler = sampler_cls(
            trainset,
            num_replicas=args.world_size,
            rank=rank,
            shuffle=True,
        )
        # DistributedSampler shuffles by default; evaluation must not, so
        # that the distributed path matches the sequential order used by
        # the non-distributed test loader below.
        test_sampler = sampler_cls(
            testset,
            num_replicas=args.world_size,
            rank=rank,
            shuffle=False,
        )

    trainloader = torch.utils.data.DataLoader(
        trainset,
        batch_size=args.batch_size_per_gpu,
        # DataLoader rejects shuffle=True together with an explicit
        # sampler, so shuffle here only in the non-distributed case.
        shuffle=not distributed,
        num_workers=args.num_workers_per_gpu,
        pin_memory=True,
        sampler=train_sampler,
    )
    testloader = torch.utils.data.DataLoader(
        testset,
        batch_size=args.batch_size_per_gpu,
        num_workers=args.num_workers_per_gpu,
        pin_memory=True,
        sampler=test_sampler,
    )
    return trainloader, testloader

def get_nonsampler_trainloader(args, trainset):
    """Return a DataLoader over ``trainset`` with no custom sampler.

    No ``shuffle`` or ``sampler`` argument is passed, so the loader uses
    DataLoader's defaults for iteration order. Batch size and worker count
    are read from ``args.batch_size_per_gpu`` and
    ``args.num_workers_per_gpu``; memory pinning is always requested.
    """
    loader_options = {
        "batch_size": args.batch_size_per_gpu,
        "num_workers": args.num_workers_per_gpu,
        "pin_memory": True,
    }
    return torch.utils.data.DataLoader(trainset, **loader_options)